{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(batch_inference_overview)=\n",
    "# Batch inference\n",
    "\n",
    "Batch inference or offline inference addresses the need to run machine learning model on large datasets. It is the process of generating outputs on a batch of observations.\n",
    "\n",
    "With batch inference, the batch runs are typically generated during some recurring schedule (e.g., hourly, or daily). These inferences are then stored in a database or a file and can be made available to developers or end users. With batch inference, the goal is usually tied to time constraints and the service-level agreement (SLA) of the job. Conversely, in real time serving, the goal is usually to optimize the number of transactions per second that the model can process. An online application displays a result to the user.\n",
    "\n",
    "Batch inference can sometimes take advantage of big data technologies such as Spark to generate predictions. Big data technologies allows data scientists and machine learning engineers to take advantage of scalable compute resources to generate many predictions at once."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test your model\n",
    "\n",
    "To evaluate batch model prior to deployment, you should use the `evaluate` handler of the `auto_trainer` function.\n",
    "\n",
    "This is typically done during model development. For more information refer to the {ref}`auto_trainer_evaluate` handler documentation. For example:\n",
    "``` python\n",
    "import mlrun\n",
    "\n",
    "# Set the base project name\n",
    "project_name_base = 'batch-inference'\n",
    "\n",
    "# Initialize the MLRun project object\n",
    "project = mlrun.get_or_create_project(project_name_base, context=\"./\", user_project=True)\n",
    "\n",
    "auto_trainer = project.set_function(mlrun.import_function(\"hub://auto_trainer\"))\n",
    "\n",
    "evaluate_run = project.run_function(\n",
    "    auto_trainer,\n",
    "    handler=\"evaluate\",\n",
    "    inputs={\"dataset\": train_run.outputs['test_set']},\n",
    "    params={\n",
    "        \"model\": train_run.outputs['model'],\n",
    "        \"label_columns\": \"labels\",\n",
    "    },\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploy your model\n",
    "\n",
    "Batch inference is implemented in MLRun by running the function with an input dataset. With MLRun you can easily create any custom logic in a function, including loading a model and calling it.\n",
    "\n",
    "The Function Hub [batch inference function](https://github.com/mlrun/functions/tree/development/batch_inference) is used for running the models in batch as well as performing drift analysis. The function supports the following frameworks:\n",
    "\n",
    "- Scikit-learn\n",
    "- XGBoost\n",
    "- LightGBM  \n",
    "- Tensorflow/Keras\n",
    "- PyTorch\n",
    "- ONNX\n",
    "\n",
    "Internally the function uses MLRun's out-of-the-box capability to load run a model via the {py:class}`mlrun.frameworks.auto_mlrun.auto_mlrun.AutoMLRun` class."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Basic example\n",
    "The simplest example to run the function is as follows:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create project\n",
    "\n",
    "Import MLRun and create a project:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mlrun\n",
    "\n",
    "project = mlrun.get_or_create_project(\n",
    "    \"batch-inference\", context=\"./\", user_project=True\n",
    ")\n",
    "batch_inference = mlrun.import_function(\"hub://batch_inference\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Get the model\n",
    "\n",
    "Get the model. The model is a [decision tree classifier](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html) from scikit-learn. Note that if you previously trained your model using MLRun, you can reference the model artifact produced during that training process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = mlrun.get_sample_path(\"models/batch-predict/model.pkl\")\n",
    "\n",
    "model_artifact = project.log_model(\n",
    "    key=\"model\", model_file=model_path, framework=\"sklearn\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Get the data\n",
    "\n",
    "Get the dataset to perform the inference. The dataset is in `parquet` format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction_set_path = mlrun.get_sample_path(\"data/batch-predict/prediction_set.parquet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run the batch inference function    \n",
    "\n",
    "Run the inference. In the first example we will not perform any drift analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_run = project.run_function(\n",
    "    batch_inference,\n",
    "    inputs={\"dataset\": prediction_set_path},\n",
    "    params={\"model\": model_artifact.uri},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Function output\n",
    "\n",
    "The output of the function is an artifact called `prediction`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>feature_0</th>\n",
       "      <th>feature_1</th>\n",
       "      <th>feature_2</th>\n",
       "      <th>feature_3</th>\n",
       "      <th>feature_4</th>\n",
       "      <th>feature_5</th>\n",
       "      <th>feature_6</th>\n",
       "      <th>feature_7</th>\n",
       "      <th>feature_8</th>\n",
       "      <th>feature_9</th>\n",
       "      <th>...</th>\n",
       "      <th>feature_11</th>\n",
       "      <th>feature_12</th>\n",
       "      <th>feature_13</th>\n",
       "      <th>feature_14</th>\n",
       "      <th>feature_15</th>\n",
       "      <th>feature_16</th>\n",
       "      <th>feature_17</th>\n",
       "      <th>feature_18</th>\n",
       "      <th>feature_19</th>\n",
       "      <th>predicted_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-2.059506</td>\n",
       "      <td>-1.314291</td>\n",
       "      <td>2.721516</td>\n",
       "      <td>-2.132869</td>\n",
       "      <td>-0.693963</td>\n",
       "      <td>0.376643</td>\n",
       "      <td>3.017790</td>\n",
       "      <td>3.876329</td>\n",
       "      <td>-1.294736</td>\n",
       "      <td>0.030773</td>\n",
       "      <td>...</td>\n",
       "      <td>2.775699</td>\n",
       "      <td>2.361580</td>\n",
       "      <td>0.173441</td>\n",
       "      <td>0.879510</td>\n",
       "      <td>1.141007</td>\n",
       "      <td>4.608280</td>\n",
       "      <td>-0.518388</td>\n",
       "      <td>0.129690</td>\n",
       "      <td>2.794967</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-1.190382</td>\n",
       "      <td>0.891571</td>\n",
       "      <td>3.726070</td>\n",
       "      <td>0.673870</td>\n",
       "      <td>-0.252565</td>\n",
       "      <td>-0.729156</td>\n",
       "      <td>2.646563</td>\n",
       "      <td>4.782729</td>\n",
       "      <td>0.318952</td>\n",
       "      <td>-0.781567</td>\n",
       "      <td>...</td>\n",
       "      <td>1.101721</td>\n",
       "      <td>3.723400</td>\n",
       "      <td>-0.466867</td>\n",
       "      <td>-0.056224</td>\n",
       "      <td>3.344701</td>\n",
       "      <td>0.194332</td>\n",
       "      <td>0.463992</td>\n",
       "      <td>0.292268</td>\n",
       "      <td>4.665876</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-0.996384</td>\n",
       "      <td>-0.099537</td>\n",
       "      <td>3.421476</td>\n",
       "      <td>0.162771</td>\n",
       "      <td>-1.143458</td>\n",
       "      <td>-1.026791</td>\n",
       "      <td>2.114702</td>\n",
       "      <td>2.517553</td>\n",
       "      <td>-0.154620</td>\n",
       "      <td>-0.465423</td>\n",
       "      <td>...</td>\n",
       "      <td>1.729386</td>\n",
       "      <td>2.820340</td>\n",
       "      <td>-1.041428</td>\n",
       "      <td>-0.331871</td>\n",
       "      <td>2.909172</td>\n",
       "      <td>2.138613</td>\n",
       "      <td>-0.046252</td>\n",
       "      <td>-0.732631</td>\n",
       "      <td>4.716266</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-0.289976</td>\n",
       "      <td>-1.680019</td>\n",
       "      <td>3.126478</td>\n",
       "      <td>-0.704451</td>\n",
       "      <td>-1.149112</td>\n",
       "      <td>1.174962</td>\n",
       "      <td>2.860341</td>\n",
       "      <td>3.753661</td>\n",
       "      <td>-0.326119</td>\n",
       "      <td>2.128411</td>\n",
       "      <td>...</td>\n",
       "      <td>2.328688</td>\n",
       "      <td>3.397321</td>\n",
       "      <td>-0.932060</td>\n",
       "      <td>-1.442370</td>\n",
       "      <td>2.058517</td>\n",
       "      <td>3.881936</td>\n",
       "      <td>2.090635</td>\n",
       "      <td>-0.045832</td>\n",
       "      <td>4.197315</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.294866</td>\n",
       "      <td>1.044919</td>\n",
       "      <td>2.924139</td>\n",
       "      <td>0.814049</td>\n",
       "      <td>-1.455054</td>\n",
       "      <td>-0.270432</td>\n",
       "      <td>3.380195</td>\n",
       "      <td>2.339669</td>\n",
       "      <td>1.029101</td>\n",
       "      <td>-1.171018</td>\n",
       "      <td>...</td>\n",
       "      <td>1.283565</td>\n",
       "      <td>0.677006</td>\n",
       "      <td>-2.147444</td>\n",
       "      <td>-0.494150</td>\n",
       "      <td>3.222041</td>\n",
       "      <td>6.219348</td>\n",
       "      <td>-1.914110</td>\n",
       "      <td>0.317786</td>\n",
       "      <td>4.143443</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 21 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   feature_0  feature_1  feature_2  feature_3  feature_4  feature_5  \\\n",
       "0  -2.059506  -1.314291   2.721516  -2.132869  -0.693963   0.376643   \n",
       "1  -1.190382   0.891571   3.726070   0.673870  -0.252565  -0.729156   \n",
       "2  -0.996384  -0.099537   3.421476   0.162771  -1.143458  -1.026791   \n",
       "3  -0.289976  -1.680019   3.126478  -0.704451  -1.149112   1.174962   \n",
       "4  -0.294866   1.044919   2.924139   0.814049  -1.455054  -0.270432   \n",
       "\n",
       "   feature_6  feature_7  feature_8  feature_9  ...  feature_11  feature_12  \\\n",
       "0   3.017790   3.876329  -1.294736   0.030773  ...    2.775699    2.361580   \n",
       "1   2.646563   4.782729   0.318952  -0.781567  ...    1.101721    3.723400   \n",
       "2   2.114702   2.517553  -0.154620  -0.465423  ...    1.729386    2.820340   \n",
       "3   2.860341   3.753661  -0.326119   2.128411  ...    2.328688    3.397321   \n",
       "4   3.380195   2.339669   1.029101  -1.171018  ...    1.283565    0.677006   \n",
       "\n",
       "   feature_13  feature_14  feature_15  feature_16  feature_17  feature_18  \\\n",
       "0    0.173441    0.879510    1.141007    4.608280   -0.518388    0.129690   \n",
       "1   -0.466867   -0.056224    3.344701    0.194332    0.463992    0.292268   \n",
       "2   -1.041428   -0.331871    2.909172    2.138613   -0.046252   -0.732631   \n",
       "3   -0.932060   -1.442370    2.058517    3.881936    2.090635   -0.045832   \n",
       "4   -2.147444   -0.494150    3.222041    6.219348   -1.914110    0.317786   \n",
       "\n",
       "   feature_19  predicted_label  \n",
       "0    2.794967                0  \n",
       "1    4.665876                0  \n",
       "2    4.716266                0  \n",
       "3    4.197315                1  \n",
       "4    4.143443                1  \n",
       "\n",
       "[5 rows x 21 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_run.artifact(\"prediction\").as_df().head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### View the results in the UI\n",
    "\n",
    "The output is saved as a parquet file under the project artifact path. In the UI you can go to the `batch-inference-infer` job --> artifact tab to view the details.\n",
    "\n",
    "![batch prediction results](../_static/images/batch_inference_prediction_artifact.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scheduling a batch run\n",
    "\n",
    "To schedule a run, you can set the schedule parameter of the run method. The scheduling is done by using a cron format.\n",
    "\n",
    "You can also schedule runs from the dashboard. On the Projects > Jobs and Workflows page, you can create a new job using the New Job wizard. At the end of the wizard, you can set the job scheduling. In the following example, the job is set to run every 30 minutes.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_run = project.run_function(\n",
    "    batch_inference,\n",
    "    inputs={\"dataset\": prediction_set_path},\n",
    "    params={\"model\": model_artifact.uri},\n",
    "    schedule=\"*/30 * * * *\",\n",
    ")"
   ]
  },
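  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid for the `schedule` value used above, the following minimal sketch splits a cron expression into its fields. It assumes the standard five-field cron layout (minute, hour, day of month, month, day of week); `describe_cron` is a hypothetical helper written for this illustration and is not part of MLRun:\n",
    "``` python\n",
    "# Field names of a standard five-field cron expression, in order\n",
    "CRON_FIELDS = (\"minute\", \"hour\", \"day_of_month\", \"month\", \"day_of_week\")\n",
    "\n",
    "\n",
    "def describe_cron(expression: str) -> dict:\n",
    "    \"\"\"Map each field of a five-field cron expression to its name (hypothetical helper).\"\"\"\n",
    "    parts = expression.split()\n",
    "    if len(parts) != len(CRON_FIELDS):\n",
    "        raise ValueError(f\"expected {len(CRON_FIELDS)} fields, got {len(parts)}\")\n",
    "    return dict(zip(CRON_FIELDS, parts))\n",
    "\n",
    "\n",
    "# The schedule used in this example: \"*/30\" in the minute field\n",
    "print(describe_cron(\"*/30 * * * *\"))\n",
    "\n",
    "# Another illustrative schedule: minute 0 of hour 2, every day\n",
    "print(describe_cron(\"0 2 * * *\"))\n",
    "```\n",
    "\n",
    "The string passed to `schedule` is the expression itself; the helper only labels its fields so that a schedule can be checked by eye before it is submitted."
   ]
  },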
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Drift analysis\n",
    "\n",
    "By default, if a model has a sample set statistics, `batch_inference` performs drift analysis and will produce a data drift table artifact, as well as numerical drift metrics.\n",
    "\n",
    "To provide sample set statistics for the model you can either:\n",
    "\n",
    "1. Train the model using MLRun. This allows you to create the sample set during training.\n",
    "2. Log an external model using `project.log_model` method and provide the training set in the `training_set` parameter.\n",
    "3. Provide the set explicitly when calling the `batch_inference` function via the `sample_set` input.\n",
    "\n",
    "In the example below, we will provide the training set as the sample set\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_set_path = mlrun.get_sample_path(\"data/batch-predict/training_set.parquet\")\n",
    "\n",
    "batch_run = project.run_function(\n",
    "    batch_inference,\n",
    "    inputs={\"dataset\": prediction_set_path, \"sample_set\": training_set_path},\n",
    "    params={\n",
    "        \"model\": model_artifact.uri,\n",
    "        \"label_columns\": \"label\",\n",
    "        \"perform_drift_analysis\": True,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, instead of just prediction, you will get drift analysis. The drift table plot that compares the drift between the training data and prediction data per feature:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_run.artifact(\"drift_table_plot\").show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![batch inference drift table plot](../_static/images/tutorial/drift_table_plot.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You also get a numerical drift metric and boolean flag denoting whether or not data drift is detected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'drift_status': False, 'drift_metric': 0.29934242566253266}\n"
     ]
    }
   ],
   "source": [
    "print(batch_run.status.results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'feature_0': 0.028086840976606773,\n",
       " 'feature_1': 0.04485072701663093,\n",
       " 'feature_2': 0.7391279921664593,\n",
       " 'feature_3': 0.043769819014849734,\n",
       " 'feature_4': 0.042755641152500176,\n",
       " 'feature_5': 0.05184219833790496,\n",
       " 'feature_6': 0.7262042202197605,\n",
       " 'feature_7': 0.7297906294873706,\n",
       " 'feature_8': 0.039060131873550404,\n",
       " 'feature_9': 0.04468363504674985,\n",
       " 'feature_10': 0.042567035578799796,\n",
       " 'feature_11': 0.7221431701127441,\n",
       " 'feature_12': 0.7034787615778625,\n",
       " 'feature_13': 0.04239724655474124,\n",
       " 'feature_14': 0.046364723781764774,\n",
       " 'feature_15': 0.6329075683793959,\n",
       " 'feature_16': 0.7181622588902428,\n",
       " 'feature_17': 0.03587785749574268,\n",
       " 'feature_18': 0.04443732609382538,\n",
       " 'feature_19': 0.7902698698155215,\n",
       " 'label': 0.017413285340161608}"
      ]
     },
     "metadata": {
      "application/json": {
       "expanded": false,
       "root": "root"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Data/concept drift per feature (use batch_run.artifact(\"features_drift_results\").get() to obtain the raw data)\n",
    "batch_run.artifact(\"features_drift_results\").show()"
   ]
  },
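  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To relate the per-feature values above to the `drift_threshold` (default 0.7) and `possible_drift_threshold` (default 0.5) parameters, the following sketch labels a few of them. It is an illustration only: `classify_drift` is a hypothetical helper, and the greater-or-equal comparison is an assumption of this sketch rather than a statement of how `batch_inference` applies the thresholds internally:\n",
    "``` python\n",
    "def classify_drift(value, drift_threshold=0.7, possible_drift_threshold=0.5):\n",
    "    \"\"\"Label a drift value (hypothetical helper; the >= comparison is an assumption).\"\"\"\n",
    "    if value >= drift_threshold:\n",
    "        return \"drift\"\n",
    "    if value >= possible_drift_threshold:\n",
    "        return \"possible drift\"\n",
    "    return \"no drift\"\n",
    "\n",
    "\n",
    "# A few per-feature values copied from the output above\n",
    "features_drift = {\n",
    "    \"feature_0\": 0.028086840976606773,\n",
    "    \"feature_2\": 0.7391279921664593,\n",
    "    \"feature_15\": 0.6329075683793959,\n",
    "}\n",
    "\n",
    "for name, value in features_drift.items():\n",
    "    print(name, classify_drift(value))\n",
    "```\n",
    "\n",
    "Under these assumptions, `feature_2` is labeled as drift, `feature_15` as possible drift, and `feature_0` as no drift."
   ]
  },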
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `batch_inference` Parameters\n",
    "\n",
    "**Model Parameters**\n",
    "\n",
    "* `model`: `str` &mdash; The model store path.\n",
    "\n",
    "**Inference parameters**\n",
    "\n",
    "*Parameters to specify the dataset for inference.*\n",
    "\n",
    "* `dataset`: `DatasetType` &mdash; The dataset to infer through the model.\n",
    "   Can be passed in `inputs` as either a Dataset artifact / Feature vector URI or\n",
    "   in `parameters` as a list, dictionary or numpy array.\n",
    "* `drop_columns`: `Union[str, int, List[str], List[int]]` &mdash; A string / integer or a list of strings / integers that\n",
    "   represent the column names / indices to drop. When the dataset is a list or a numpy array this parameter must\n",
    "   be represented by integers.\n",
    "* `label_columns`: `Union[str, List[str]]` &mdash; The target label(s) of the column(s) in the dataset for Regression\n",
    "   or classification tasks. The label column can be accessed from the model object, or the feature vector provided\n",
    "   if available.\n",
    "* `predict_kwargs`: `Dict[str, Any]` &mdash; Additional parameters to pass to the prediction of the model.\n",
    "\n",
    "**Drift parameters**\n",
    "\n",
    "*Parameters that affect the drift calculation.*\n",
    "\n",
    "* `perform_drift_analysis`: `bool` = `None` &mdash; Whether to perform drift analysis between the sample set of the\n",
    "   model object to the dataset given. By default, None, which means it will perform drift analysis if the model has\n",
    "   a sample set statistics. Perform drift analysis will produce a data drift table artifact.\n",
    "* `sample_set`: `DatasetType` &mdash; A sample dataset to give to compare the inputs in the drift analysis. The\n",
    "   default chosen sample set will always be the one who is set in the model artifact itself.\n",
    "* `drift_threshold`: `float` = `0.7` &mdash; The threshold of which to mark drifts. Default is 0.7.\n",
    "* `possible_drift_threshold`: `float` = `0.5` &mdash; The threshold of which to mark possible drifts. Default is 0.5.\n",
    "* `inf_capping`: `float` = `10.0` &mdash; The value to set for when it reached infinity. Default is 10.0.\n",
    "\n",
    "**Logging parameters**\n",
    "\n",
    "*Parameters to control the automatic logging feature of MLRun. You can adjust the logging outputs as relevant and if\n",
    "not passed, a default list of artifacts and metrics is produced and calculated.*\n",
    "\n",
    "* `log_result_set`: `bool` = `True` &mdash; Whether to log the result set - a DataFrame of the given inputs concatenated\n",
    "   with the predictions. Default is True.\n",
    "* `result_set_name`: `str` = `\"prediction\"` &mdash; The db key to set name of the prediction result and the filename\n",
    "   Default is 'prediction'.\n",
    "* `artifacts_tag`: `str` &mdash; Tag to use for all the artifacts resulted from the function.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
